{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EAsv3uyVZERg"
      },
      "source": [
        "# Quantization Aware Training (QAT)\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/quantization_aware_training.ipynb)\n",
        "\n",
        "This is an example on how to obtain and run quantized versions of Gemma models. It's best to first read the [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) colab to understand this one.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "executionInfo": {
          "elapsed": 3,
          "status": "ok",
          "timestamp": 1741556668372,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "7EM9c9_kZERg"
      },
      "outputs": [],
      "source": [
        "!pip install -q gemma"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "executionInfo": {
          "elapsed": 5222,
          "status": "ok",
          "timestamp": 1741556673717,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "j85GsCFMZERg"
      },
      "outputs": [],
      "source": [
        "# Common imports\n",
        "import os\n",
        "import treescope\n",
        "import optax\n",
        "\n",
        "# Gemma imports\n",
        "from kauldron import kd\n",
        "from gemma import gm\n",
        "from gemma import peft"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Chiw4WkZERg"
      },
      "source": [
        "By default, Jax do not utilize the full GPU memory, but this can be overwritten. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "executionInfo": {
          "elapsed": 52,
          "status": "ok",
          "timestamp": 1741556673898,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "yh5RMMLy4DR1"
      },
      "outputs": [],
      "source": [
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-apwpcjNZERg"
      },
      "source": [
        "## Config updates\n",
        "\n",
        "If you're familiar with the [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) tutorial, switching to QAT only require 1 change to the trainer.\n",
        "\n",
        "This is slightly different to LoRA (we discuss the difference below)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Id0PQE16ZERg"
      },
      "source": [
        "### 1. Model\n",
        "\n",
        "Wrap the model in the `gm.nn.QuantizationAwareTrainingWrapper`. This will apply model surgery to replace all the linear and compatible layers with Simulation for Quantized layers. You can choose among several options for quantization:\n",
        "\n",
        "* SFP8: switched floating point in 8 bits (very efficient with gemma.cpp)\n",
        "* Q4_0: per-block integer quantization (equivalent to 4.5 bits per weights), very popular on llama.cpp\n",
        "* INT4: per-channel weight quantization (almost exactly 4 bits per weights)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "executionInfo": {
          "elapsed": 9380,
          "status": "ok",
          "timestamp": 1741556683410,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "rKHZ6d3-ZERg"
      },
      "outputs": [],
      "source": [
        "model = gm.nn.QuantizationAwareWrapper(\n",
        "    method = peft.QuantizationMethod.INT8,\n",
        "    model=gm.nn.Gemma3_4B(tokens=\"batch.input\", text_only=True),\n",
        ")"
      ]
    },
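    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the idea of simulated quantization more concrete, here is a minimal JAX sketch of symmetric per-channel INT4 fake quantization with a straight-through estimator. It is an illustration only, not the `gemma.peft` implementation: `fake_quant_int4` is a hypothetical helper, and the scaling scheme (mapping the largest magnitude in each column to 7) is an assumption. The last line reproduces the 4.5 bits per weight figure for Q4_0, assuming the llama.cpp layout of 32 four-bit weights plus one 16-bit scale per block."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def fake_quant_int4(w):\n",
        "  # Hypothetical helper, for illustration only.\n",
        "  # One scale per column, mapping the largest |w| in that column to 7.\n",
        "  scale = jnp.max(jnp.abs(w), axis=0, keepdims=True) / 7.0\n",
        "  scale = jnp.where(scale == 0, 1.0, scale)  # Avoid dividing by zero.\n",
        "  # Round to the signed 4-bit grid [-8, 7], then map back to floats.\n",
        "  w_q = jnp.clip(jnp.round(w / scale), -8, 7) * scale\n",
        "  # Straight-through estimator: the forward pass uses w_q, while the\n",
        "  # gradient flows to w as if quantization were the identity.\n",
        "  return w + jax.lax.stop_gradient(w_q - w)\n",
        "\n",
        "\n",
        "w = jnp.array([[0.10, -0.70], [0.35, 0.20]])\n",
        "print(fake_quant_int4(w))\n",
        "\n",
        "# Q4_0 storage cost, assuming 32 four-bit weights + one 16-bit scale per block.\n",
        "print((32 * 4 + 16) / 32)"
      ]
    },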
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9MBCCglxdwht"
      },
      "source": [
        "Internally, this uses the [`gemma.peft`](https://github.com/google-deepmind/gemma/blob/main/gemma/peft) mini-library to perform model surgery."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "executionInfo": {
          "elapsed": 54,
          "status": "ok",
          "timestamp": 1741556683592,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "lS_OkUkWgsvP"
      },
      "outputs": [],
      "source": [
        "init_transform = wrapped=gm.ckpts.LoadCheckpoint(\n",
        "    path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O9kiQmjgdwht"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ybKFzaE_dwht"
      },
      "source": [
        "### Data pipeline\n",
        "\n",
        "Like for the [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example, we recreate the tokenizer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "executionInfo": {
          "elapsed": 822,
          "status": "ok",
          "timestamp": 1741556684537,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "u5MZYuTBdwht",
        "outputId": "7eaa77fc-179e-4f8a-cfa7-cb5cd63e2597"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[\u003c_Gemma3SpecialTokens.BOS: 2\u003e, 2094, 563, 614, 2591, 13315]"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "tokenizer = gm.text.Gemma3Tokenizer()\n",
        "\n",
        "tokenizer.encode('This is an example sentence', add_bos=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0QNkBrD8dwht"
      },
      "source": [
        "And the data pipeline:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "height": 394
        },
        "executionInfo": {
          "elapsed": 17257,
          "status": "ok",
          "timestamp": 1741556701922,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "ULN7oQ_8dwht",
        "outputId": "781108d4-eb01-479e-8028-96320966967c"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\u003cscript\u003e (()=\u003e{ if (customElements.get('treescope-container') === undefined) { class TreescopeContainer extends HTMLElement { constructor() { super(); this.attachShadow({mode: \"open\"}); this.defns = {}; this.state = {}; } } customElements.define(\"treescope-container\", TreescopeContainer); } if (customElements.get('treescope-run-here') === undefined) { class RunHere extends HTMLElement { constructor() { super() } connectedCallback() { const run = child =\u003e { const fn = new Function(child.textContent); child.textContent = \"\"; fn.call(this); this.remove(); }; const child = this.querySelector(\"script\"); if (child) { run(child); } else { new MutationObserver(()=\u003e{ run(this.querySelector(\"script\")); }).observe(this, {childList: true}); } } } customElements.define(\"treescope-run-here\", RunHere); } })(); \u003c/script\u003e \u003ctreescope-container class=\"treescope_out_71880ff0037841b3a5a31e2d4320694e\" \u003e\u003c/treescope-container\u003e \u003ctreescope-run-here\u003e\u003cscript type=\"application/octet-stream\"\u003e const root = ( Array.from(document.getElementsByClassName( \"treescope_out_71880ff0037841b3a5a31e2d4320694e\")) .filter((elt) =\u003e !elt.dataset.setup) )[0]; root.dataset.setup = 1; const msg = document.createElement(\"span\"); msg.style = \"color: #cccccc; font-family: monospace;\"; msg.textContent = \"(Loading...)\"; root.state.loadingMsg = msg; root.shadowRoot.appendChild(msg); root.state.chain = new Promise((resolve, reject) =\u003e { const observer = new IntersectionObserver((entries) =\u003e { for (const entry of entries) { if (entry.isIntersecting) { resolve(); observer.disconnect(); return; } } }, {rootMargin: \"1000px\"}); window.setTimeout(() =\u003e { observer.observe(root); }, 0); }); root.state.deferring = false; const _insertNode = (node) =\u003e { for (let oldScript of node.querySelectorAll(\"script\")) { let newScript = document.createElement(\"script\"); newScript.type = 
oldScript.type; newScript.textContent = oldScript.textContent; oldScript.parentNode.replaceChild(newScript, oldScript); } if (root.state.loadingMsg) { root.state.loadingMsg.remove(); root.state.loadingMsg = null; } root.shadowRoot.appendChild(node); }; root.defns.insertContent = ((contentNode, compressed) =\u003e { if (compressed) { root.state.deferring = true; } if (root.state.deferring) { root.state.chain = (async () =\u003e { await root.state.chain; if (compressed) { const encoded = contentNode.textContent; const blob = new Blob([ Uint8Array.from(atob(encoded), (m) =\u003e m.codePointAt(0)) ]); const reader = blob.stream().pipeThrough( new DecompressionStream(\"deflate\") ).pipeThrough( new TextDecoderStream(\"utf-8\") ).getReader(); const parts = []; while (true) { const step = await reader.read(); if (step.done) { break; } parts.push(step.value); } const tpl = document.createElement('template'); tpl.innerHTML = parts.join(\"\"); _insertNode(tpl.content); } else { _insertNode(contentNode.content); } })(); } else { _insertNode(contentNode.content); } }); \u003c/script\u003e\u003c/treescope-run-here\u003e\u003cdiv style=\"display:none\"\u003e \u003cscript type=\"application/octet-stream\" 
\u003eeNrtGOtW2zj6/z6Fmp4zJAsxuZAAoXDWCbnRAoXQQun0ZBRbtkUcychyQpjDA8x77L7YPMl+knNPoO2W7XZnJjknjqXvru+qV6Ec+uTAkIKQ0OIBaQvOJfoVBTykknJWQoL4WNI+2UMOZzLt4B71hyXU44yHAbZgfeBRSdL6pYQCASs+DWVak07LYQCrjDNY7mCr6woeMTttcZ+LUoy6h0ZvHR8AgB61pVdCDpUAxiRhcg8F2LYpc9M+cWQJ5SxPMWEk7RHqerCSNQqKDJOYgswTtNGfdJ+GtEN9KkFyHEk+gU1TJgVlIbXSIb0n8e5I3IdXm7F5Xk3MkxYRA54C1kJL0EAipd/+Gg4Cn1pYWWyTW5Io7QXBvbWDZDK1fwAGBX6hRDZxWIj2kfRoaLhEnoO1T7hNkinD46E09D6oRiRqB4QplU1LUVVIHz+t2mlgZvsEtlnk+3sxBwPEbHHOYDU54KKbQrMy8EtYUltzy5JaajEgwuGih5lFDMYHyZQ+X2CQXNpB6RjpFcrnUkCHOii5ILXhE+ZKD+3vo4wCeVJ0QWQkGNgdET8kU8G8iCnJFkmHHnWkkk8DqD8P8H2EQxK8itl8YAhyG5FQmoz29HHVBO6RZGyTlKKxt8QoiEIvNuPeCh3HLPZjNZ7Q8stlUFLEBym56/pxVLZ15IC3BoqWWiG+3ECkDw4+OkklnX43umSojJ4QCSXQCNiwfByGbyA4R3STiQnNdg/cMDFm/pACe4L7ax8/eLW5KgBs2kea4H5iPn0kkMQd0JTc7ScyCQhdIZdBOAMRwRgMtp4KhtUWSCqcse4JCMY4jek80sadjiB97T86rbws7uRwJgNajQAs3usBYnvgEdYmdwEcELFLjMuk4XDfxh3gx0CQkofD5IGPO8Q/mN9px1JpnJLlEatL7FQK/T01wxXrz5RrqBKNOwuQzxayBQUwR/uR7PuAVooA0DYNAx8Px1l2ERAdIK1CqdQhEMBkRgJLf/ZW8oszbTqrUu0oI4MJJ7wo0/m343OVsh/l6fE+EcucbSy6IcEuOAVbxn6mo5jIoFBXI43h5yTUJaOE1n7OFTrW2v9SvHmkR4Usfgch1TkqxpEI1QEGHAonESv40vD52OpQ0IzSOtTDx3z8ebhO1ZPkTi5zMWjYdqgIZZuztnL/FaH1VCgZuYKKppVHhb5Z/PjEF0VUWvWwcKHVicXQAf3wjdwgpQXDTiQl9BirEtB0e5XTJlBiAQoMCf3jauCfSXbLTiy0n2vHGLyCYh+1hr0O90N0Gkmlr40qMSY8gyEERnpAOl1oJTV62IMy4kEahkaPSUCnOCT2pAF9STLqu7fs5jG27hAzxi7pLWoZx8cKLVanuymmMcBh24LmEQw7wceOjKNtIU8/xXMBZ57lrOlRH4tkOm1jidOYwcHqHiQ1u6yYqMZKYDb2Zk0WZUNEwGLQNad5JL9OlYkEcDCU2C/mJdEs0QvaC7iQmC3R7gjehWqtVqbJ6PPWnUGbsef4mB8M1aaAYHbbgibXFoSNRJ0fI4DmPOCzNQuT0BkV0rlQtbBvJWHWgUY7G9zp3swIJVb4E3n/a5J0uLAh78aS2FyC7kqKWePBrNfXfZTAAVBYngMfoEnHAqQdYMEg6trjrD4+CMfBVja/AjCgiuJ46hOjIU+lrpGFRkvpjKFz6nS0LOnBEIu0K7BNQbwkyuYLNnE3EAd/dgnKQAgXLW8j9m9oT1W20EtoZOMlWZbS6vPkbLSUncf6PBi6OQU7+TiAFPX5Nu/rc/jjHOLmTvOY65EfgXkOOVaxeMYxXE1G6IUpBB4ajuA9mLasSA0BhsrsodHHPoxkyVTKCDnMYjrfq5lKPY246Kh56gvLTmINTjY1mWBDj8DYDvMhGaBKq9VS2rTUmhpa9SaMhKCyRVpDZiV/+ceo1FmqhPyHZS8uWIqTsiOM7f5obTC6MdlSE1EorBKKhJ9UObik9jcH3HFyex3I8sWtDTuzWz92zb
KpP80z0+T6X/l8AL+NmmlWzac+5Z5pul3+2m5Wy5XBB9O8+FA5Mo+b5YpZc++ajTeeDMvHlLj52uFV7k2z+KHfCiL69rhwkT26ap6/P+5fHt/Lt8NarbJ+6XYvaPkw49HDs+ioatdvMo3OptNv2sHt66J3e0npWXTM6l7DeSfNd8Xyidgya03WrRatd1HE1s8Lt1bYHfSdmr95e+dW+Y7bORrUd7INc5OZ54U3Qhxlz9fd+8y5nTGPnKx7sl0Z1G9yboYPo/Pt7V41Wxw0rnZPXTcgF93hFml27gtWR5zWJTbds+bJ4BCHw/AsajavLqu1gfn2LGh+sN9tbq672xfbV3mZcV6/vTX7BaD5xjzZNo8HZs+9P2+tR9ctUr26yzlF6/5k67wxLERl8/V9+SaoBXnaOKtUM9fR263WNnPKb6qN2nHPpOs7/WrOY1lve73zfnB1M2iI/mH9XYXdONWqK9dPrWvf3y7sVo4G5R1vd+v4uN7K169Nt9cs3JTPduVFnTR2q+Vys54/dLfONz9Yw45ZhzN9/3rTPKtjkxxXfLNxXz11r6VbLL91T0+bh+UuPSuQWvmqUq5ZNBN4ggcMfCO4rh5m77PdllNxpDd8zRo2roUNJ3PSq1dPimXbvH3/PsAybF33bBvT3Zxzv7v1jt7cFoOeKJ7yD5UWFfVe/6ieb1228rVqziqfORfrDZ8H9a1aOChg97a4Q69J68QPLlm50ST2sSDR5W290ste1kS31bor5IqXl+HABIlSSF9ZyeSadus1ld9/gZ9J9GObB1DZpiGpL9oMw3gCYiOO2U9A6+mrC0/f/OjmI+6LgDa4B7NQMm5P5u/lIAQvuApfABu1L2othPSgSKh+TDUxeICpRAz3qYslFwZQDjocC9sYCCjAFzC6JKe0QNkRrenlD1TVZGKmWVPXPsDlgvYIdHXJ8b3gEp4gPei8llAfNlAuk8no2gnJF8poUo8dq/nOdGSJqXBq4BpnMHVTlkAvUQ1THxKb5EgBv9CZDaoli7AP2ZiCzQi2VUO5Pmu70RXWZy6vVHM6vr2abWMS81vzlxqLTQAAx3XuFWVBNCpCCV3ZOvwusZLIqAjCZlwAQT6NfPDrq80RsVkBlvrTxFP7C5uz906Jg59e3uW29zSz+C+YCIAPSuhxlZ9VwxWyxVd1K8VeeUm3ADnfwYCKvhxrtQw3pfESjYBYYDCYHqApAHeSxa3kjnblFPr4+2//ymyg33/7Zy6/3c4Wtz+heyJ4KdvObu2ovkS/bRVyT4vjTsVZeKw46qcU/Vo30DolP35ECOU2EMpmCqBYHgo6gtQGCwhlZn8/bfwNxZ/vgaG2vye/L8b4pDI7Bg8f+0HqkdPTj40nzx59gR/O7q1Nh9Y1xFlFpdf9ta+sL/paPLWGJhP1fuLjmg6+tU8JpLvB/cTMwF1CP91GXO5NgOLXPbR0PQCtp86uUL09+D+S/eAp+3xbnvJ5GLZ7OOx+l1z1/5GjOpz7I9fcQNkU9AMRKeXyGeRgPySQnPLbmR88IX2sKVGnsYk+vzCbLtAX48PKJP7/NEw1gT+f2t+R6WKJUHH4h6gSk3z7uUoxA/jjVAuJBej/V6lY2c7qWvFH6Wg/qjZuNjBVQzf7Dvvbj8b5fGP4yPtcZvkjs1soF39Z9hvZfXVxeCpaNn6Y0hAn18/VhTHUsxaF+PGwuGrT/sG/AQJ3L4U=\u003c/script\u003e \u003ctreescope-run-here\u003e\u003cscript type=\"application/octet-stream\"\u003e const root = ( Array.from(document.getElementsByClassName( \"treescope_out_71880ff0037841b3a5a31e2d4320694e\")) .filter((elt) =\u003e !elt.dataset['step0']) )[0]; root.dataset['step0'] = 
1; root.defns.insertContent( this.parentNode.querySelector('script[type=\"application/octet-stream\"]'), true ); this.parentNode.remove(); \u003c/script\u003e\u003c/treescope-run-here\u003e \u003c/div\u003e"
            ],
            "text/plain": [
              "{\n",
              "  'input': # np.ndarray int64(8, 200) [≥0, ≤237_167] zero:1_148 nonzero:452\n",
              "    array([[   2,  105, 2364, ...,    0,    0,    0],\n",
              "           [   2,  105, 2364, ...,    0,    0,    0],\n",
              "           [   2,  105, 2364, ...,    0,    0,    0],\n",
              "           ...,\n",
              "           [   2,  105, 2364, ...,    0,    0,    0],\n",
              "           [   2,  105, 2364, ...,    0,    0,    0],\n",
              "           [   2,  105, 2364, ...,    0,    0,    0]], shape=(8, 200))\n",
              "  ,\n",
              "  'loss_mask': \u003cnp.ndarray bool(8, 200, 1) true:230 false:1_370\u003e,\n",
              "  'target': \u003cnp.ndarray int64(8, 200, 1) [≥0, ≤237_167] zero:1_148 nonzero:452\u003e,\n",
              "}"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "ds = kd.data.py.Tfds(\n",
        "    name='mtnt/en-fr',\n",
        "    split='train',\n",
        "    shuffle=True,\n",
        "    batch_size=8,\n",
        "    transforms=[\n",
        "        # Create the model inputs/targets/loss_mask.\n",
        "        gm.data.Seq2SeqTask(\n",
        "            # Select which field from the dataset to use.\n",
        "            # https://www.tensorflow.org/datasets/catalog/mtnt\n",
        "            in_prompt='src',\n",
        "            in_response='dst',\n",
        "            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}\n",
        "            out_input='input',\n",
        "            out_target='target',\n",
        "            out_target_mask='loss_mask',\n",
        "            tokenizer=tokenizer,\n",
        "            # Padding parameters\n",
        "            max_length=200,\n",
        "            truncate=True,\n",
        "        ),\n",
        "    ],\n",
        ")\n",
        "\n",
        "ex = ds[0]\n",
        "\n",
        "treescope.show(ex)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AH7sBVr8dwht"
      },
      "source": [
        "We can decode an example from the batch to inspect the model input and check it is properly formatted:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "executionInfo": {
          "elapsed": 52,
          "status": "ok",
          "timestamp": 1741556702133,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "4krpPavGdwht",
        "outputId": "b2da5cd2-4231-4dfe-b591-138ff0d96b84"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u003cstart_of_turn\u003euser\n",
            "Is this a good place to ask about the ethnicity and intelligence debate?\u003cend_of_turn\u003e\n",
            "\u003cstart_of_turn\u003emodel\n",
            "Est-ce un bon endroit pour poser des questions sur le débat à propos de l'ethnicité et le renseignement ?\u003cend_of_turn\u003e\n"
          ]
        }
      ],
      "source": [
        "text = tokenizer.decode(ex['input'][0])\n",
        "\n",
        "print(text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V0nJX7uddwht"
      },
      "source": [
        "### Trainer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P0lMcUindwht"
      },
      "source": [
        "We then create the trainer, reusing the `model`, `init_transform` and `optimizer` created above:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "executionInfo": {
          "elapsed": 54,
          "status": "ok",
          "timestamp": 1741556702316,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "XGj5dw_Gdwht"
      },
      "outputs": [],
      "source": [
        "trainer = kd.train.Trainer(\n",
        "    seed=42,  # The seed of enlightenment\n",
        "    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default\n",
        "    # Dataset\n",
        "    train_ds=ds,\n",
        "    # Model\n",
        "    model=model,\n",
        "    init_transform=init_transform,\n",
        "    # Training parameters\n",
        "    num_train_steps=500,\n",
        "    train_losses={\n",
        "        \"loss\": kd.losses.SoftmaxCrossEntropyWithIntLabels(\n",
        "            logits=\"preds.logits\",\n",
        "            labels=\"batch.target\",\n",
        "            mask=\"batch.loss_mask\",\n",
        "        ),\n",
        "    },\n",
        "    optimizer=optax.adafactor(learning_rate=0.005),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FT0plMVmdwht"
      },
      "source": [
        "Trainning can be launched with the `.train()` method.\n",
        "\n",
        "Note that the trainer like the model are immutables, so it does not store the state nor params. Instead the state containing the trained parameters is returned."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "height": 171,
          "referenced_widgets": [
            "d50e55815c444675a8e669e781035b34",
            "a8fb8f12675e450cba2901cc78b2981a",
            "e4593e87157347bb8a9df5936d8b73e5",
            "6361fa8889444bc8a45a00488493fee1",
            "1600c7d9b6c3448c925102245c5d4b10",
            "7894811af08543d29481b3bec06b736b",
            "9f743749221b4cb281a3d2514e695840",
            "65f711cd25734c838b1fbc24ef914c1e",
            "c25865c509ac4517a1efe17e7b2063a4",
            "18e141276be24e00875769cbee883832",
            "8c81664c18894eadaf847cd00983330b"
          ]
        },
        "executionInfo": {
          "elapsed": 304719,
          "status": "ok",
          "timestamp": 1741557007194,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "8UmU7u-Hdwht",
        "outputId": "375d18b6-e07b-4621-920a-7de5efe01488"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Configuring ...\n",
            "Initializing ...\n",
            "Disabling pygrain multi-processing (unsupported in colab).\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Starting training loop at step 0\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d50e55815c444675a8e669e781035b34",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "train:   0%|          | 0/501 [00:00\u003c?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "state, aux = trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jGzopkGIdwht"
      },
      "source": [
        "## Inference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D1sgP7-jdwht"
      },
      "source": [
        "In order to infer the model, you have two options:\n",
        "\n",
        "1. simply evaluate the `QATWrapper`: that does not provide any memory footprint reduction\n",
        "2. use the `IntWrapper` as follows (only available for INT8 quantization)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "executionInfo": {
          "elapsed": 5006,
          "status": "ok",
          "timestamp": 1741557012334,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "OJf7ARvkdwht"
      },
      "outputs": [],
      "source": [
        "quantized_model = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(tokens=\"batch.input\", text_only=True))\n",
        "quantized_params = peft.quantize(state.params, method=peft.QuantizationMethod.INT8)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1qLNZ3xvfxIk"
      },
      "source": [
        "then evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "colab": {
          "height": 35
        },
        "executionInfo": {
          "elapsed": 15473,
          "status": "ok",
          "timestamp": 1741557027937,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "zX9ntj71fxIk",
        "outputId": "b2f0e65d-03c3-49b8-92d8-8350ae7d5221"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'Je me sens bien !\u003cend_of_turn\u003e'"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "sampler = gm.text.Sampler(\n",
        "    model=quantized_model,\n",
        "    params=quantized_params,\n",
        "    tokenizer=tokenizer,\n",
        ")\n",
        "\n",
        "prompt = \"\"\"\\\n",
        "\u003cstart_of_turn\u003euser\n",
        "I'm feeling happy!\u003cend_of_turn\u003e\n",
        "\u003cstart_of_turn\u003emodel\n",
        "\"\"\"\n",
        "\n",
        "sampler.sample(prompt, max_new_tokens=30)"
      ]
    },
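    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For comparison, option 1 from the list above can be sketched as follows. This assumes that `gm.text.Sampler` accepts the QAT-wrapped `model` together with the trained `state.params` in the same way it accepts `quantized_model`; that is an assumption and has not been verified here. The name `qat_sampler` is only illustrative. Since the weights keep their original precision, this path does not reduce memory usage."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Assumption: the sampler works with the QAT-wrapped model and the\n",
        "# un-quantized trained params returned by `trainer.train()`.\n",
        "qat_sampler = gm.text.Sampler(\n",
        "    model=model,\n",
        "    params=state.params,\n",
        "    tokenizer=tokenizer,\n",
        ")\n",
        "\n",
        "qat_sampler.sample(prompt, max_new_tokens=30)"
      ]
    },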
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8sxaHy2Zdwht"
      },
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {},
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "1600c7d9b6c3448c925102245c5d4b10": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "18e141276be24e00875769cbee883832": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6361fa8889444bc8a45a00488493fee1": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_18e141276be24e00875769cbee883832",
            "placeholder": "​",
            "style": "IPY_MODEL_8c81664c18894eadaf847cd00983330b",
            "value": " 501/501 [04:27\u0026lt;00:00,  5.36s/it]"
          }
        },
        "65f711cd25734c838b1fbc24ef914c1e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7894811af08543d29481b3bec06b736b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8c81664c18894eadaf847cd00983330b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9f743749221b4cb281a3d2514e695840": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a8fb8f12675e450cba2901cc78b2981a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7894811af08543d29481b3bec06b736b",
            "placeholder": "​",
            "style": "IPY_MODEL_9f743749221b4cb281a3d2514e695840",
            "value": "train: 100%"
          }
        },
        "c25865c509ac4517a1efe17e7b2063a4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "d50e55815c444675a8e669e781035b34": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_a8fb8f12675e450cba2901cc78b2981a",
              "IPY_MODEL_e4593e87157347bb8a9df5936d8b73e5",
              "IPY_MODEL_6361fa8889444bc8a45a00488493fee1"
            ],
            "layout": "IPY_MODEL_1600c7d9b6c3448c925102245c5d4b10"
          }
        },
        "e4593e87157347bb8a9df5936d8b73e5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_65f711cd25734c838b1fbc24ef914c1e",
            "max": 501,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_c25865c509ac4517a1efe17e7b2063a4",
            "value": 501
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
